import os
import csv
from p_tqdm import p_map
from argparse import ArgumentParser
from copy import deepcopy

import numpy as np
import pickle

from c2_uplift_bandit import C2UpBandit, interact_c2
from c2_uplift_learner import C2UpUCB


def main(params):
    """Run ``params.n_runs`` bandit simulations in parallel and save the results.

    Loads the subsampled arrays from ``params.data_dir``, optionally loads a
    pickled baseline model from ``params.model_dir``, writes the parameters
    to a CSV file, runs ``run_para`` once per seed through ``p_map`` and
    saves the stacked uplifts and rewards as ``.npy`` files in
    ``params.save_dir``.

    Args:
        params: Namespace-like object exposing the attributes read below
            (``data_dir``, ``model_dir``, ``save_dir``, ``use_true_data``,
            ``use_baseline_model``, ``baseline_option``, ``radius``, ``logt``,
            ``epsilon``, ``random_seed``, ``n_runs``, ``n_rounds``,
            ``n_samples_per_round``, ``budget``, ``n_cpus``, ``verbose``).
    """
    X = np.load(os.path.join(params.data_dir, 'X_subsampled.npy'))
    visit_proba = np.load(os.path.join(params.data_dir, 'visit_proba_subsampled.npy'))

    if params.use_true_data:
        filename = 'data'
        treat = np.load(os.path.join(params.data_dir, 'treat_subsampled.npy'))
        visit = np.load(os.path.join(params.data_dir, 'visit_subsampled.npy'))
    else:
        filename = 'generated'
        treat, visit = None, None

    if params.use_baseline_model:
        filename += '-with_baseline'
        filename_control = os.path.join(params.model_dir, 'model_control_subsample_lg.sav')
        # NOTE: unpickling can execute arbitrary code; only point model_dir
        # at files you trust.
        with open(filename_control, 'rb') as model_file:
            baseline_model = pickle.load(model_file)
    else:
        filename += '-without_baseline'
        baseline_model = None

    filename += f'-{params.baseline_option}-radius_{params.radius}'
    if params.logt:
        filename += '-logt'
    filename += f'-epsilon_{params.epsilon}'
    # Dots (e.g. from float values) are replaced only in the part of the name
    # built so far; the suffixes appended below are left untouched.
    filename = filename.replace('.', '_')
    filename += f'-rng_{params.random_seed}'
    filename += f'-runs_{params.n_runs}'
    filename += f'-rounds_{params.n_rounds}'
    filename += f'-samples_{params.n_samples_per_round}_{params.budget}'

    filename_uplifts = os.path.join(params.save_dir, filename + '-uplifts.npy')
    filename_rewards = os.path.join(params.save_dir, filename + '-rewards.npy')
    csv_file = os.path.join(params.save_dir, filename + '-params.csv')

    # Record the exact configuration next to the results. newline='' lets the
    # csv module control line endings itself, as its documentation requires.
    with open(csv_file, 'w', newline='') as csvfile:
        csv.writer(csvfile).writerows(vars(params).items())

    n_treatments = 1
    # NOTE(review): presumably 12 features plus one extra dimension -- confirm.
    vdim = 12 + 1

    # One argument tuple per run; run i uses i as its random seed, so the
    # seeds are random_seed, random_seed + 1, ..., random_seed + n_runs - 1.
    args = [
        (X, visit_proba, treat, visit,
         n_treatments, vdim, baseline_model, params.baseline_option,
         params.radius, params.logt, params.epsilon,
         params.n_rounds, params.n_samples_per_round, params.budget,
         seed, params.verbose)
        for seed in range(params.random_seed, params.random_seed + params.n_runs)
    ]
    results = p_map(run_para, args, num_cpus=params.n_cpus)
    uplifts = np.vstack([result[0] for result in results])
    rewards = np.vstack([result[1] for result in results])

    np.save(filename_uplifts, uplifts)
    np.save(filename_rewards, rewards)


def run(X, visit_proba, treat, visit,
        n_treatments, vdim, baseline_model, baseline_option,
        radius, logt, explore_proba,
        n_rounds, n_samples_per_round, budget,
        random_seed, verbose):
    """Run one bandit/learner interaction and return the learner's records.

    Builds a ``C2UpBandit`` from the data arrays and a ``C2UpUCB`` learner
    from the remaining options, lets them interact through ``interact_c2``
    and returns what the learner recorded.

    Args:
        X, visit_proba, treat, visit: Passed unchanged to ``C2UpBandit``.
        n_treatments, vdim: Positional arguments of ``C2UpUCB``.
        baseline_model: Deep-copied before being handed to the learner.
        baseline_option, radius, explore_proba: Forwarded to ``C2UpUCB``.
        logt, verbose: Forwarded to ``interact_c2``.
        n_rounds, n_samples_per_round, budget: Forwarded to ``interact_c2``.
        random_seed: Used for the bandit's ``rng_sampling`` and
            ``rng_feedback`` and for the learner's ``explore_rng``.

    Returns:
        The tuple ``(learner.uplifts, learner.rewards)``.
    """
    environment = C2UpBandit(
        X, visit_proba, treat, visit,
        rng_sampling=random_seed,
        rng_feedback=random_seed,
    )

    # The learner gets its own deep copy of the baseline model.
    learner_options = {
        'baseline_model': deepcopy(baseline_model),
        'baseline_option': baseline_option,
        'radius': radius,
        'explore_proba': explore_proba,
        'explore_rng': random_seed,
    }
    agent = C2UpUCB(n_treatments, vdim, **learner_options)

    interact_c2(
        environment, agent, n_rounds, n_samples_per_round, budget,
        logt=logt, verbose=verbose,
    )

    return agent.uplifts, agent.rewards


def run_para(args):
    """Call ``run`` with a single tuple of its positional arguments.

    Lets a one-argument mapper such as ``p_map`` drive ``run``.

    Args:
        args: Sequence of positional arguments, in the order ``run`` expects.

    Returns:
        The ``(uplifts, rewards)`` tuple produced by ``run``.
    """
    run_uplifts, run_rewards = run(*args)
    return run_uplifts, run_rewards


if __name__ == '__main__':
    # Command-line entry point: parse the experiment options, make sure the
    # output directory exists, then hand over to main().

    parser = ArgumentParser()

    # Input and output locations.
    parser.add_argument('--data_dir', type=str, default='save/criteo/data/')
    parser.add_argument('--model_dir', type=str, default='save/criteo/model/')
    parser.add_argument('--save_dir', type=str, default='save/criteo/results/')

    # Both switches default to True; the --no_* flags turn them off.
    parser.add_argument('--no_true_data', dest='use_true_data', action='store_false')
    parser.add_argument('--no_baseline', dest='use_baseline_model', action='store_false')
    parser.set_defaults(use_true_data=True, use_baseline_model=True)

    # Learner options.
    parser.add_argument('--baseline_option', type=str, default='UCB')
    parser.add_argument('--radius', type=float, default=1)
    parser.add_argument('--logt', dest='logt', action='store_true')
    parser.set_defaults(logt=False)

    parser.add_argument('--epsilon', type=float, default=0)

    # Experiment size.
    parser.add_argument('--n_runs', type=int, default=2)
    parser.add_argument('--n_rounds', type=int, default=500)
    parser.add_argument('--n_samples_per_round', type=int, default=100)
    parser.add_argument('--budget', type=int, default=10)

    parser.add_argument('--random_seed', type=int, default=0)
    parser.add_argument('--n_cpus', type=int, default=10)
    # Setting verbose > 0 together with multiprocessing can cause problems.
    parser.add_argument('--verbose', type=int, default=-1)

    params = parser.parse_args()

    # exist_ok=True avoids the check-then-create race when several jobs
    # start at once and share the same save_dir.
    os.makedirs(params.save_dir, exist_ok=True)

    main(params)
